#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Feature-based vehicle search daemon module
# Author: 王成
# Date: 2017-04-07
import os,sys,time
import logging
from logging.handlers import TimedRotatingFileHandler,RotatingFileHandler
import yaml
import json
import redis
import requests
import threading
import uuid
import Queue
from daemon import Daemon
from wheezy.http import HTTPResponse
from wheezy.http import WSGIApplication
from wheezy.routing import url
from wheezy.web.handlers import BaseHandler
from wheezy.web.middleware import bootstrap_defaults
from wheezy.web.middleware import path_routing_middleware_factory
from wsgiref.simple_server import make_server

# Registry of search tasks, keyed by task_id (uuid hex string).
TASKS = dict()
# Process-wide shared objects; the Redis client is stored under 'redis_slow'.
GLOBALS = dict()

# Caps how many detection threads may run at the same time (8 slots).
concurrency_lock = threading.BoundedSemaphore(8)

'''
Fetch results:
http://127.0.0.1:9001/api?task_id=a25189cba40443dab76d0904d8e1b6bc&action=get
Stop a task:
http://127.0.0.1:9001/api?task_id=a25189cba40443dab76d0904d8e1b6bc&action=stop
'''

def detect_object(task_id,filters_str,data,sreq,fcgi_url2):
    """Check one search record against the trained filters and queue the result.

    Args:
        task_id: key of the owning task in TASKS.
        filters_str: iterable of filter payloads; each one is uploaded as a
            'filter_files' part of the detection request.
        data: one record from the search index. Reads 'image_url1' and 'vp'
            and, on a hit, 'location_id', 'direction_id', 'license_plate',
            'info_id' and 'capture_time'.
        sreq: session-like object providing get() and post().
        fcgi_url2: URL of the detection service.

    Returns:
        None. A result dict is put on TASKS[task_id]['result'] with
        'state_code' 1 (no hit) or 2 (hit, with the record details) plus a
        'progress' ratio. Records without 'image_url1' are skipped and are
        not counted.

    The caller must have acquired concurrency_lock; it is always released
    here, whatever happens.
    """
    try:
        result = {'state_code':1,'message':''}
        if 'image_url1' not in data:
            return
        image_url = data['image_url1']
        # A record without vehicle data is still counted as processed.
        vp = data.get('vp')
        if vp:
            box = vp['vehicle_detection']
            payload = {'crop_x':box[1],'crop_y':box[2],'crop_w':box[3],'crop_h':box[4]}
            try:
                img_data = sreq.get(image_url, timeout=3).content
            except Exception as e:
                logging.error('task_id: %s 下载图片错误 [%s] : %s', task_id,image_url,str(e))
            else:
                files2 = [('image_file', ('image.jpg', img_data))]
                files2.extend(('filter_files', ('filter.dat',filter_str)) for filter_str in filters_str)
                try:
                    # NOTE(review): this request has no timeout, so a stalled
                    # service keeps the semaphore slot held - confirm intended.
                    json_result = sreq.post(fcgi_url2, files=files2, data=payload).json()
                except Exception as e:
                    logging.error('task_id: %s 分析图片错误 [%s] : %s', task_id,image_url,str(e))
                else:
                    # Count every detected object and accumulate its probability.
                    count,prob_sum = 0,0
                    for detections in json_result['data']:
                        if detections is None:
                            continue
                        for detection in detections:
                            count+=1
                            prob_sum+=detection['prob']
                    if count:
                        lp_type_id = 0
                        plates = vp.get('license_plates')
                        if plates:
                            lp_type_id = plates[0]['type_id']
                        veh_type_id = 0
                        veh_type_ids = (vp.get('veh_attrs') or {}).get('veh_type_id')
                        if veh_type_ids:
                            veh_type_id = veh_type_ids[0]
                        result.update({'image_url1':image_url,'location_id':data['location_id'],'direction_id':data['direction_id'],'license_plate':data['license_plate'],'lp_type_id':lp_type_id,'veh_type_id':veh_type_id,'info_id':data['info_id'],'capture_time':data['capture_time'],'label_id':vp['label_id'][0],'prob':prob_sum,'count':count,'state_code':2})
        task = TASKS[task_id]
        task['current']+=1
        total = task['total']
        # Guard against an unknown/zero total so the result is not lost to a
        # ZeroDivisionError.
        progress = float(task['current'])/float(total) if total else 0.0
        result.update({"progress":progress})
        task['result'].put(result)
    except Exception as e:
        logging.exception('task_id: %s 检测时错误 [%s]', task_id,str(e))
    finally:
        concurrency_lock.release()
    
def _fetch_index_page(index_url,params):
    # Request one page from the search index; None means a non-OK response.
    r = requests.post(index_url, data=params)
    if r.status_code != requests.codes.ok:
        return None
    return r.json()


def _advance_cursor(params,json_data):
    # Move the query on by 100 records and carry over the 'cid' the index returned.
    params['o']+=100
    params['cid'] = json_data['cid']


def worker(task_id,img_data,car_box,boxes,index_url,post_data,fcgi_url):
    """Train filters on the query image, then scan index records with them.

    Args:
        task_id: key of the task entry in TASKS.
        img_data: raw bytes of the query image.
        car_box: 'x,y,w,h' crop string, or a falsy value for no crop (-1s).
        boxes: iterable of dicts with 'x', 'y', 'w' and 'h' entries.
        index_url: URL of the search index.
        post_data: dict holding the 'p2' and 'p3' query dicts. Each query is
            mutated in place ('o' grows by 100 per page, 'cid' is copied from
            the reply). 'p2' is drained completely up front; 'p3' is paged
            lazily while records are being analysed.
        fcgi_url: service URL containing "vehicle"; the training and
            detection URLs are derived from it by substitution.

    Every record is handed to detect_object() on its own thread, throttled by
    concurrency_lock. The scan stops early when the client has not polled for
    60 seconds or the task was flagged as stopped. The task is always marked
    stopped when this function ends, including after an unexpected error.
    """
    try:
        sreq = requests.Session()

        fcgi_url1 = fcgi_url.replace("vehicle","train_on_one_image")
        fcgi_url2 = fcgi_url.replace("vehicle","detect_object")

        if car_box:
            crop_x,crop_y,crop_w,crop_h = [int(v) for v in car_box.split(',')]
        else:
            crop_x,crop_y,crop_w,crop_h = -1,-1,-1,-1
        payload = {'bbox_x':[int(box['x']) for box in boxes]
                   ,'bbox_y':[int(box['y']) for box in boxes]
                   ,'bbox_w':[int(box['w']) for box in boxes]
                   ,'bbox_h':[int(box['h']) for box in boxes]
                   ,'crop_x':crop_x
                   ,'crop_y':crop_y
                   ,'crop_w':crop_w
                   ,'crop_h':crop_h
                   }
        logging.warning('task_id: %s 建模中', task_id)
        files = {'image_file': ('image.jpg', img_data)}
        r = sreq.post(fcgi_url1, files=files, data=payload)
        if r.status_code == requests.codes.ok:
            json_result = r.json()
            if len(json_result['data']) > 0:
                filters_str = json_result['data']
                all_data = []
                p2 = post_data.get('p2')
                p3 = post_data.get('p3')

                if p2:
                    while 1:
                        json_data = _fetch_index_page(index_url, p2)
                        if json_data is None:
                            # Give up on this query instead of retrying forever.
                            logging.error('task_id: %s 搜索接口异常', task_id)
                            break
                        if not json_data['data']:
                            break
                        _advance_cursor(p2, json_data)
                        all_data+=json_data['data']
                total = len(all_data)
                if p3:
                    json_data = _fetch_index_page(index_url, p3)
                    if json_data is not None and json_data['data']:
                        _advance_cursor(p3, json_data)
                        all_data+=json_data['data']
                        total+= json_data['total_items']
                TASKS[task_id]['total'] = total
                logging.warning('task_id: %s 开始分析 [%d]', task_id,total)
                while 1:
                    if not all_data:
                        # Buffer drained: pull the next page of the lazy query.
                        if not p3:
                            break
                        json_data = _fetch_index_page(index_url, p3)
                        if json_data is None or not json_data['data']:
                            break
                        _advance_cursor(p3, json_data)
                        all_data+=json_data['data']
                    data = all_data.pop()
                    if not data:
                        break
                    # Blocks until a detection slot is free; detect_object releases it.
                    concurrency_lock.acquire()
                    detector = threading.Thread(target=detect_object, args=(task_id,filters_str,data,sreq,fcgi_url2))
                    detector.start()
                    if TASKS[task_id]['last_get']+60 < int(time.time()) or TASKS[task_id]['stop']==1:
                        TASKS[task_id]['stop'] = 1
                        # Drop results nobody will read.
                        TASKS[task_id]['result'] = Queue.Queue(0)
                        logging.warning('task_id: %s 失效的任务,停止分析', task_id)
                        return

        TASKS[task_id]['stop'] = 1
        logging.warning('task_id: %s 分析结束', task_id)
        return
    except Exception as e:
        logging.exception('task_id: %s 分析时错误 [%s]', task_id,str(e))
        # Mark the task finished so polling clients are not left waiting.
        task = TASKS.get(task_id)
        if task is not None:
            task['stop'] = 1

class Feature(BaseHandler):
    """HTTP handler for feature-search tasks.

    POST creates a task and starts its worker thread; GET polls a task for
    results (action=get) or stops it (action=stop). Every reply is a JSON
    object with a 'state_code': 0 modelling, 1 analysing, 2 one result,
    3 finished, 9 error.
    """

    def _reply(self, payload):
        # Serialise payload as the JSON body of a fresh response.
        response = HTTPResponse(content_type='application/json; charset=UTF-8')
        response.write(json.dumps(payload))
        return response

    def get(self):
        """Poll or stop the task named by the 'task_id' query parameter.

        action=get returns the next queued result: progress-only entries are
        merged and skipped until one carrying 'info_id' (a hit) is found or
        the queue is empty. action=stop flags the task as stopped. An unknown
        task id or action yields state_code 9.
        """
        task_id = self.request.get_param("task_id")
        action = self.request.get_param("action")
        task = TASKS.get(task_id)
        if task is None:
            return self._reply({'state_code':9,'message':'invalid task_id'})
        if action=='stop':
            logging.warning('task_id: %s 用户停止任务', task_id)
            task['stop'] = 1
            return self._reply({'state_code':3,'message':''})
        if action!='get':
            # Previously fell through and returned None to the framework.
            return self._reply({'state_code':9,'message':'param error'})

        if task['stop']==1 and task['result'].qsize()==0:
            return self._reply({'state_code':3,'message':''})
        # The worker abandons a task that has not been polled for 60 seconds.
        task['last_get'] = int(time.time())
        r = {'state_code':0,'message':''}
        try:
            # Look the queue up each time: the worker may swap it for a new one.
            while task['result'].qsize()>0:
                result = task['result'].get(timeout=1)
                r.update(result)
                if 'info_id' in result:
                    r.update({'state_code':2})
                    break
                r.update({'state_code':1})
        except Queue.Empty:
            pass
        except Exception as e:
            r.update({'state_code':9,'message':str(e)})
            logging.exception('task_id: %s 获取分析结果错误', task_id)
        return self._reply(r)

    def post(self):
        """Create a search task from the submitted form and start its worker.

        Required form fields: 'car_box', 'boxes' (JSON), 'index_url',
        'post_data' (JSON object with a 'p1' query), 'fcgi_url', plus the
        uploaded file 'image_file'. The 'p1' query is sent to the index to
        obtain 'total_items'. When that is positive a worker thread is
        started; otherwise the task is registered as already stopped.
        Replies with 'task_id' and 'total', or state_code 9 on bad parameters
        or an index failure.
        """
        form = self.request.form
        car_box = form.get("car_box")
        boxes = form.get("boxes")
        index_url = form.get("index_url")
        post_data = form.get("post_data")
        fcgi_url = form.get("fcgi_url")
        files = self.request.files.get('image_file')

        res = {'state_code':0,'message':''}
        if not (car_box and boxes and index_url and post_data and fcgi_url and files):
            res.update({'state_code':9,'message':'param error'})
            return self._reply(res)
        try:
            boxes = json.loads(boxes[0])
            post_data = json.loads(post_data[0])
            # 'p1' is only needed here; the worker receives the remaining queries.
            first_query = post_data.pop('p1')
        except (ValueError, KeyError, TypeError, AttributeError):
            res.update({'state_code':9,'message':'param error'})
            return self._reply(res)
        car_box = car_box[0]
        index_url = index_url[0]
        fcgi_url = fcgi_url[0]
        img_data = files[0].value

        try:
            r = requests.post(index_url, data=first_query)
        except requests.RequestException as e:
            logging.error('搜索接口异常 [%s] : %s', index_url, str(e))
            res.update({'state_code':9,'message':'index error'})
            return self._reply(res)
        if r.status_code != requests.codes.ok:
            # No task exists yet at this point, so log the index URL instead.
            logging.error('搜索接口异常 [%s] : %s', index_url, r.status_code)
            res.update({'state_code':9,'message':'index error'})
            return self._reply(res)

        # Normalise a missing count before it is formatted with %d.
        total = r.json()['total_items'] or 0
        task_id = uuid.uuid4().hex
        logging.warning('task_id: %s 新任务 [%d]', task_id,total)
        now = int(time.time())
        runnable = total>0
        TASKS[task_id] = {"stop":0 if runnable else 1,"total":total,"result":Queue.Queue(0),"last_get":now,"create_time":now,'current':0}
        if runnable:
            starter = threading.Thread(target=worker, args=(task_id,img_data,car_box,boxes,index_url,post_data,fcgi_url))
            starter.start()
            # Forget tasks created more than 36000 seconds ago; iterate over
            # a snapshot because entries are removed during the pass.
            for old_id,old_task in list(TASKS.items()):
                if old_task['create_time']+36000<int(time.time()):
                    TASKS.pop(old_id, None)
        res.update({'task_id':task_id,'total':total})
        return self._reply(res)

def get_redis(request):
    """Real-time tracking endpoint: pop the next queued message for a task.

    Reads the 'task_id' query parameter and pops the head of the Redis list
    'ZZ:<task_id>', refreshing its 60 second expiry. An empty-string entry is
    pushed back onto the tail of the list and not returned. The response body
    is the popped value, or '{}' when nothing is available or no task_id was
    supplied.
    """
    body = '{}'
    task_id = request.get_param("task_id")
    if task_id:
        store = GLOBALS['redis_slow']
        key = 'ZZ:' + task_id
        store.expire(key, 60)
        item = store.lpop(key)
        if item == '':
            # Put the empty marker back so the list keeps existing.
            store.rpush(key, '')
            store.expire(key, 60)
        elif item is not None:
            body = item
            store.expire(key, 60)
    response = HTTPResponse(content_type='application/json; charset=UTF-8')
    response.write(body)
    return response

class MyDaemon(Daemon):
    """Daemon that serves the feature-search HTTP API."""

    def run(self):
        """Load config.yaml, set up logging and Redis, then serve forever.

        Reads 'redis_slow.unix_socket_path' and 'feature_search.host', 'port'
        and 'thread_num' from config.yaml located next to this file. Exits
        the process with status 1 on KeyboardInterrupt or any start-up error.
        """
        # Rebind the module-level semaphore; without this declaration the
        # configured thread_num would only create an unused local.
        global concurrency_lock
        # Bound before the try block so the error handler can always use it.
        name = 'yisa_feature_search'

        all_urls = [
            url('api', Feature, name='default'),
            url('get_redis', get_redis, name='welcome')
        ]
        options = {}
        main = WSGIApplication(
            middleware=[
                bootstrap_defaults(url_mapping=all_urls),
                path_routing_middleware_factory
            ],
            options=options
        )

        try:
            config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
            with open(config_path) as config_file:
                config = yaml.safe_load(config_file)

            logging.basicConfig(level=logging.INFO)
            handler = RotatingFileHandler('/var/log/%s.log' % name, maxBytes=134217728, backupCount=7)
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logging.getLogger('').addHandler(handler)
            #-------------------also log to the console-------------------
            # console = logging.StreamHandler()
            # console.setLevel(logging.INFO)
            # formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            # console.setFormatter(formatter)
            # logging.getLogger().addHandler(console)
            #---------------------------------------------------------------
            GLOBALS['redis_slow'] = redis.StrictRedis(unix_socket_path=config['redis_slow']['unix_socket_path'])

            logging.warning('启动 [%s]', name)
            host = config['feature_search']['host']
            port = config['feature_search']['port']
            concurrency_lock = threading.BoundedSemaphore(value=config['feature_search']['thread_num'])
            logging.warning('监听 [%s:%s]', host,port)

            make_server(host, int(port), main).serve_forever()
        except KeyboardInterrupt:
            sys.exit(1)
        except Exception as e:
            logging.exception('启动 [%s] 错误: %s' , name,str(e))
            sys.exit(1)
    
# Command-line entry point: start, stop or restart the daemon.
if __name__ == "__main__":
    daemon = MyDaemon('/var/run/yisa_feature_search.pid')
    if len(sys.argv) == 2:
        command = sys.argv[1]
        if command == 'start':
            daemon.start()
        elif command == 'stop':
            daemon.stop()
        elif command == 'restart':
            daemon.restart()
        else:
            # NOTE(review): any other argument calls run() directly, i.e. the
            # server runs in the foreground; the message below is only
            # reached if run() returns. Looks like a debugging aid - confirm.
            daemon.run()
            print("Unknown command")
            sys.exit(2)
        sys.exit(0)
    else:
        print("usage: %s start|stop|restart" % sys.argv[0])
        sys.exit(2)
